from geesibling.adapters.pytorch.pipeline.megatron import mpu



def patch_config(config):
    """Attach pipeline-parallel topology attributes to *config* in place.

    Queries the current pipeline-parallel state from ``mpu`` and stores it
    on the given config object so downstream model-construction code can
    read it (e.g. to decide whether this stage holds the embedding or the
    output head).

    Args:
        config: A HuggingFace-style model config object (e.g. ``LlamaConfig``)
            to be mutated in place.

    Returns:
        The same ``config`` object, now carrying ``pp_rank`` (this stage's
        pipeline rank), ``pre_process`` (True on the first stage),
        ``post_process`` (True on the last stage), and ``pp_size``
        (pipeline world size).
    """
    # Plain attribute assignment is clearer than setattr() with constant
    # string names; behavior is identical.
    config.pp_rank = mpu.get_pipeline_model_parallel_rank()
    config.pre_process = mpu.is_pipeline_first_stage()
    config.post_process = mpu.is_pipeline_last_stage()
    config.pp_size = mpu.get_pipeline_model_parallel_world_size()
    return config